# Copyright (c) 2025 Huawei Technologies Co., Ltd.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.

from typing import Callable, Union, overload, TypeVar

from ..._C import ir
from ..core.dtype import KnownTypes as KT
from ..core.enums import GatherMaskMode
from ..core.ir_value import materialize_ir_value as _mat
from ..core.tensor import LocalTensor
from ..core.utils import require_jit, global_builder
from ..core.types import GatherMaskParams
from .utils import OverloadDispatcher, set_common_docstring

T = TypeVar("T", bound=Callable)


def check_type_gather_mask(dst: LocalTensor, src0: LocalTensor, src1_pattern: Union[LocalTensor, int]) -> None:
    """
    Check data type constraints for GatherMask operation.
    
    According to GatherMask specification:
    - T (dst and src0 data types):
        Atlas inference AI Core: half/uint16_t/int16_t/float/uint32_t/int32_t
        Atlas A2 training/Atlas 800I A2 inference/A200I A2 Box: half/bfloat16_t/uint16_t/int16_t/float/uint32_t/int32_t
        Atlas 200I/500 A2 inference: half/uint16_t/int16_t/float/uint32_t/int32_t
        Atlas A3 training/A3 inference: half/bfloat16_t/uint16_t/int16_t/float/uint32_t/int32_t
    
    - U (src1_pattern data type when LocalTensor):
        uint16_t/uint32_t
        When dst is half/uint16_t/int16_t: src1_pattern should be uint16_t
        When dst is float/uint32_t/int32_t: src1_pattern should be uint32_t
    """
    valid_dst_src0_types = [
        KT.half,       
        KT.uint16,     
        KT.int16,      
        KT.float_,     
        KT.uint32,     
        KT.int32,      
    ]    
    if dst.dtype not in valid_dst_src0_types:
        raise TypeError(f"Invalid dst data type for GatherMask: {dst.dtype}. "
                       f"Supported types: half, uint16, int16, float, uint32, int32")    
    if src0.dtype not in valid_dst_src0_types:
        raise TypeError(f"Invalid src0 data type for GatherMask: {src0.dtype}. "
                       f"Supported types: half, uint16, int16, float, uint32, int32")
    if dst.dtype != src0.dtype:
        raise TypeError(f"dst and src0 must have same data type. Got dst={dst.dtype}, src0={src0.dtype}")
    if isinstance(src1_pattern, LocalTensor):
        if dst.dtype in [KT.half, KT.uint16, KT.int16]:
            if src1_pattern.dtype != KT.uint16:
                raise TypeError(f"For dst data type {dst.dtype}, src1_pattern must be uint16. Got {src1_pattern.dtype}")
        elif dst.dtype in [KT.float_, KT.uint32, KT.int32]:
            if src1_pattern.dtype != KT.uint32:
                raise TypeError(f"For dst data type {dst.dtype}, src1_pattern must be uint32. Got {src1_pattern.dtype}")
        else:
            raise TypeError(f"Unsupported dst data type for src1_pattern validation: {dst.dtype}")    
    elif isinstance(src1_pattern, int):
        if not (1 <= src1_pattern <= 7):
            raise ValueError(f"Built-in src1_pattern must be between 1 and 7. Got {src1_pattern}")
    else:
        raise TypeError(f"src1_pattern must be either LocalTensor or int. Got {type(src1_pattern)}")


@overload
def gather_mask(dst: LocalTensor, src0: LocalTensor, src1_pattern: LocalTensor,
               reduce_mode: bool, mask: int, params: GatherMaskParams,
               rsvd_cnt: int, gather_mask_mode=GatherMaskMode.DEFAULT):
    ...


@overload
def gather_mask(dst: LocalTensor, src0: LocalTensor, src1_pattern: int,
               reduce_mode: bool, mask: int, params: GatherMaskParams,
               rsvd_cnt: int, gather_mask_mode=GatherMaskMode.DEFAULT):
    ...


@require_jit
@set_common_docstring("gather_mask")
def gather_mask(dst: LocalTensor, src0: LocalTensor, *args, **kwargs):
    builder = global_builder.get_ir_builder()
    
    dispatcher = OverloadDispatcher("gather_mask")
    
    @dispatcher.register_auto
    def _(src1_pattern: LocalTensor, reduce_mode: bool, mask: int,
          params: GatherMaskParams, rsvd_cnt: int, gather_mask_mode: GatherMaskMode):
        check_type_gather_mask(dst, src0, src1_pattern)
        rsvd_cnt_var = builder.create_memref_AllocaOp(ir.get_memref_type(builder.get_ui64_type(), 1), False)
        builder.create_asc_GatherMaskOp(
            dst.to_ir(), src0.to_ir(), src1_pattern.to_ir(),
            _mat(reduce_mode, KT.bool_).to_ir(), _mat(mask, KT.uint32).to_ir(),
            params.to_ir(), rsvd_cnt_var, gather_mask_mode
        )

    @dispatcher.register_auto
    def _(src1_pattern: int, reduce_mode: bool, mask: int,
          params: GatherMaskParams, rsvd_cnt: int, gather_mask_mode: GatherMaskMode):
        check_type_gather_mask(dst, src0, src1_pattern)
        rsvd_cnt_var = builder.create_memref_AllocaOp(ir.get_memref_type(builder.get_ui64_type(), 1), False)
        builder.create_asc_GatherMaskOp(
            dst.to_ir(), src0.to_ir(), _mat(src1_pattern, KT.uint8).to_ir(),
            _mat(reduce_mode, KT.bool_).to_ir(), _mat(mask, KT.uint32).to_ir(),
            params.to_ir(), rsvd_cnt_var, gather_mask_mode
        )
    dispatcher(*args, **kwargs)


@require_jit
def get_gather_mask_remain_count() -> int:
    builder = global_builder.get_ir_builder()
    result = builder.create_asc_GetGatherMaskRemainCountOp(builder.get_ui64_type())
    return result